import dalex as dx
import pandas as pd
import pickle
from sklearn.model_selection import train_test_split
path_to_data = '../../PracaDomowa3/Sawicki_Bartosz/new_preprocessed_dataset.csv'
input_df = pd.read_csv(path_to_data)
y = input_df.loc[:,'Attrition']
X = input_df.drop('Attrition', axis='columns')
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=14)
path_to_models = '../../../Projekt/Modele/BarteKasiAdam/'
xgb = pickle.load(open(path_to_models + 'new_xgb_model.p', "rb" ))
exp_xgb = dx.Explainer(xgb, X_train, y_train, label='XGB')
rf = pickle.load(open(path_to_models + 'new_random_forest_model.p', "rb" ))
exp_rf = dx.Explainer(rf, X_train, y_train, label='RandomForest')
reg = pickle.load(open(path_to_models + 'l1_log_reg.p', "rb" ))
exp_reg = dx.Explainer(reg, X_train, y_train, label='LogisticRegression')
Preparation of a new explainer is initiated -> data : 7595 rows 21 cols -> target variable : Parameter 'y' was a pandas.Series. Converted to a numpy.ndarray. -> target variable : 7595 values -> model_class : xgboost.sklearn.XGBClassifier (default) -> label : XGB -> predict function : <function yhat_proba_default at 0x7f20689ba820> will be used (default) -> predict function : Accepts pandas.DataFrame and numpy.ndarray. -> predicted values : min = 1.49e-06, mean = 0.159, max = 1.0 -> model type : classification will be used (default) -> residual function : difference between y and yhat (default) -> residuals : min = -0.699, mean = 9.19e-05, max = 0.82 -> model_info : package xgboost A new explainer has been created!
/home/sawcio/Studia/4sem/Warsztaty_badawcze/wb-env/lib/python3.8/site-packages/sklearn/base.py:310: UserWarning: Trying to unpickle estimator DecisionTreeClassifier from version 0.23.2 when using version 0.24.1. This might lead to breaking code or invalid results. Use at your own risk. warnings.warn( /home/sawcio/Studia/4sem/Warsztaty_badawcze/wb-env/lib/python3.8/site-packages/sklearn/base.py:310: UserWarning: Trying to unpickle estimator RandomForestClassifier from version 0.23.2 when using version 0.24.1. This might lead to breaking code or invalid results. Use at your own risk. warnings.warn(
Preparation of a new explainer is initiated -> data : 7595 rows 21 cols -> target variable : Parameter 'y' was a pandas.Series. Converted to a numpy.ndarray. -> target variable : 7595 values -> model_class : sklearn.ensemble._forest.RandomForestClassifier (default) -> label : RandomForest -> predict function : <function yhat_proba_default at 0x7f20689ba820> will be used (default) -> predict function : Accepts pandas.DataFrame and numpy.ndarray. -> predicted values : min = 0.0, mean = 0.159, max = 1.0 -> model type : classification will be used (default) -> residual function : difference between y and yhat (default) -> residuals : min = -0.37, mean = 0.000521, max = 0.44 -> model_info : package sklearn A new explainer has been created! Preparation of a new explainer is initiated -> data : 7595 rows 21 cols -> target variable : Parameter 'y' was a pandas.Series. Converted to a numpy.ndarray. -> target variable : 7595 values -> model_class : sklearn.linear_model._logistic.LogisticRegression (default) -> label : LogisticRegression -> predict function : <function yhat_proba_default at 0x7f20689ba820> will be used (default) -> predict function : Accepts pandas.DataFrame and numpy.ndarray. -> predicted values : min = 4.17e-06, mean = 0.159, max = 0.984 -> model type : classification will be used (default) -> residual function : difference between y and yhat (default) -> residuals : min = -0.879, mean = 0.000387, max = 0.999 -> model_info : package sklearn A new explainer has been created!
pdp_xgb = exp_xgb.model_profile(random_state=14)
pdp_rf = exp_rf.model_profile(random_state=14)
pdp_reg = exp_reg.model_profile(random_state=14)
Calculating ceteris paribus: 100%|██████████| 21/21 [00:00<00:00, 24.16it/s] Calculating ceteris paribus: 100%|██████████| 21/21 [00:04<00:00, 4.85it/s] Calculating ceteris paribus: 100%|██████████| 21/21 [00:00<00:00, 50.65it/s]
pdp_xgb.plot([pdp_rf, pdp_reg])
Wykres dla wielu zmiennych jest prostą równoległą do OX na poziomie średniej predykcji modeli. Oznacza to, że nie są one istotne w procesie podejmowania decyzji.
Dla lepszej czytelności wybierzmy zmienne, które mają ciekawe wykresy.
var_names = ['Total_Ct_Chng_Q4_Q1', 'Total_Revolving_Bal',
'Total_Amt_Chng_Q4_Q1', 'Contacts_Count_12_mon',
'Total_Trans_Amt']
pdp_xgb = exp_xgb.model_profile(random_state=14, variables=var_names)
pdp_rf = exp_rf.model_profile(random_state=14, variables=var_names)
pdp_reg = exp_reg.model_profile(random_state=14, variables=var_names)
Calculating ceteris paribus: 100%|██████████| 5/5 [00:00<00:00, 11.02it/s] Calculating ceteris paribus: 100%|██████████| 5/5 [00:01<00:00, 4.59it/s] Calculating ceteris paribus: 100%|██████████| 5/5 [00:00<00:00, 54.90it/s]
pdp_xgb.plot([pdp_rf, pdp_reg])
Total_Trans_Amt (całkowita wartość transakcji) i Total_Revolving_Bal (całkowita wartość przeniesiona na następny okres rozliczeniowy) mają największy wpływ na wynik predykcji, co potwierdza hipotezy stawiane od EDA, przez wyjaśnienia lokalne do analizy permutacyjnej ważności zmiennych. Total_Ct_Chng_Q4_Q1 i Total_Amt_Chng_Q4_Q1 (Względna zmiana odpowiednio liczby i wielkości transakcji od 4. do 1. kwartału) sprzyja odejściu z banku. Co ciekawe, Total_Amt_Chng_Q4_Q1 w modelu regresji logistycznej nie wpływa na predykcję. Widzimy też, że zmiana liczby transakcji ma większe znaczenie niż zmiana w wielkości transakcji. Jednak te zmienne są ze sobą skorelowane, więc profile PDP mogą nie oddawać rzeczywistej sytuacji. Sprawdzimy wpływ tych zmiennych za pomocą profili ALE w następnym punkcie.Contacts_Count_12_mon (liczba kontaktów klienta z przedstawicielami banku w ostatnich 12 miesiącach) rośnie bardzo wyraźnie, gdy zbliża się do 6. Intuicyjnie można wytłumaczyć "Klient był niezadowolony z usługi, szukał alternatywy lub poprawy sytuacji w banku i niemogąc jej znaleźć decydował się na rezygnację z karty". Pozostaje jednak pytanie, dlaczego modele nie reagują gdy Contacts_Count_12_mon <= 5 (wtedy prawie stale osiągają poziom średniej predykcji). Postanowiłem sprawdzić co szczególnego jest w 6 wizycie w banku.print('Odsetek obserwacji w zbiorze, gdzie Contacts_Count_12_mon > 5: {:2.2%}'.
format(len(X_train[X_train['Contacts_Count_12_mon']>5])/len(X_train)))
print('Odsetek obserwacji, gdzie klient odszedł, wśród obserwacji gdzie Contacts_Count_12_mon > 5 : {:2.2%}'.
format(len(X_train[(X_train['Contacts_Count_12_mon']>5) & (y_train==1)])/
len(X_train[X_train['Contacts_Count_12_mon']>5])))
Odsetek obserwacji w zbiorze, gdzie Contacts_Count_12_mon > 5: 0.58% Odsetek obserwacji, gdzie klient odszedł, wśród obserwacji gdzie Contacts_Count_12_mon > 5 : 100.00%
Okazało się, że jest stosunkowo niewielu klientów, którzy kontaktowali się z bankiem więcej niż 5 razy oraz że wszyscy z nich zrezygnowali z karty. Uważam, że są dwie możliwości:
ale_xgb = exp_xgb.model_profile(type = 'accumulated', random_state=14)
ale_rf = exp_rf.model_profile(type = 'accumulated', random_state=14)
ale_reg = exp_reg.model_profile(type = 'accumulated', random_state=14)
Calculating ceteris paribus: 100%|██████████| 21/21 [00:01<00:00, 12.65it/s] Calculating accumulated dependency: 100%|██████████| 21/21 [00:02<00:00, 7.41it/s] Calculating ceteris paribus: 100%|██████████| 21/21 [00:04<00:00, 4.57it/s] Calculating accumulated dependency: 100%|██████████| 21/21 [00:02<00:00, 9.87it/s] Calculating ceteris paribus: 100%|██████████| 21/21 [00:00<00:00, 36.90it/s] Calculating accumulated dependency: 100%|██████████| 21/21 [00:02<00:00, 7.52it/s]
ale_xgb.plot([ale_rf, ale_reg])
W przypadku profili ALE większość wykresów także nie wnosi żadnej informacji. Dla uproszczenia wygenerujemy tylko istotne wizualizacje.
ale_xgb = exp_xgb.model_profile(type = 'accumulated', random_state=14, variables=var_names)
ale_rf = exp_rf.model_profile(type = 'accumulated', random_state=14, variables=var_names)
ale_reg = exp_reg.model_profile(type = 'accumulated', random_state=14, variables=var_names)
Calculating ceteris paribus: 100%|██████████| 5/5 [00:00<00:00, 13.52it/s] Calculating accumulated dependency: 100%|██████████| 5/5 [00:00<00:00, 9.50it/s] Calculating ceteris paribus: 100%|██████████| 5/5 [00:00<00:00, 5.67it/s] Calculating accumulated dependency: 100%|██████████| 5/5 [00:00<00:00, 9.73it/s] Calculating ceteris paribus: 100%|██████████| 5/5 [00:00<00:00, 65.22it/s] Calculating accumulated dependency: 100%|██████████| 5/5 [00:00<00:00, 9.33it/s]
ale_xgb.plot([ale_rf, ale_reg])
Total_Ct_Chng_Q4_Q1 i Total_Amt_Chng_Q4_Q1 okazały się prawidłowe, gdy liczba transakcji spadnie bardziej niż o 50% w kolejnych okresach, istnieje duża szansa, że klient zrezygnuje z karty. Ważniejsza jest jednak zmiana w liczbie transakcji, a nie w ich wielkości.Contact_Count_12_mon ma podoby charakter jak w przypadku profili PDP. Obserwujemy nagły wzrost predykcji, gdy zmienna ma wartość 6.Porównując analizy dwoma metodami, możemy sprawdzić czy model jest addytywny
ale_xgb.result['_label_'] = "ALE_XGB"
pdp_xgb.result['_label_'] = "PDP_XGB"
ale_xgb.plot(pdp_xgb)
Wykresy są równoległe, więc model nie wykrywa interakcji między zmiennymi
ale_rf.result['_label_'] = "ALE_RF"
pdp_rf.result['_label_'] = "PDP_RF"
ale_rf.plot(pdp_rf)
W tym przypadku wykresy również są równoległe, zatem model Random Forest, przynajmniej dla najważniejszych zmiennych, jest addytywny.
Tu spodziwanym wynikiem jest addytywność modelu z definicji regresji logistycznej.
ale_reg.result['_label_'] = "ALE_REG"
pdp_reg.result['_label_'] = "PDP_REG"
ale_reg.plot(pdp_reg)
Zgodnie z przewidywaniami model regresji liniowej jest addytywny.